{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### XGBoost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from cart import TreeNode, BinaryDecisionTree\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from utils import cat_label_convert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "### XGBoost单棵树类\n",
    "class XGBoost_Single_Tree(BinaryDecisionTree):\n",
    "    # 结点分裂方法\n",
    "    def node_split(self, y):\n",
    "        # 中间特征所在列\n",
    "        feature = int(np.shape(y)[1]/2)\n",
    "        # 左子树为真实值，右子树为预测值\n",
    "        y_true, y_pred = y[:, :feature], y[:, feature:]\n",
    "        return y_true, y_pred\n",
    "\n",
    "    # 信息增益计算方法\n",
    "    def gain(self, y, y_pred):\n",
    "        # 梯度计算\n",
    "        Gradient = np.power((y * self.loss.gradient(y, y_pred)).sum(), 2)\n",
    "        # Hessian矩阵计算\n",
    "        Hessian = self.loss.hess(y, y_pred).sum()\n",
    "        return 0.5 * (Gradient / Hessian)\n",
    "\n",
    "    # 树分裂增益计算\n",
    "    # 式(12.28)\n",
    "    def gain_xgb(self, y, y1, y2):\n",
    "        # 结点分裂\n",
    "        y_true, y_pred = self.node_split(y)\n",
    "        y1, y1_pred = self.node_split(y1)\n",
    "        y2, y2_pred = self.node_split(y2)\n",
    "        true_gain = self.gain(y1, y1_pred)\n",
    "        false_gain = self.gain(y2, y2_pred)\n",
    "        gain = self.gain(y_true, y_pred)\n",
    "        return true_gain + false_gain - gain\n",
    "\n",
    "    # 计算叶子结点最优权重\n",
    "    def leaf_weight(self, y):\n",
    "        y_true, y_pred = self.node_split(y)\n",
    "        # 梯度计算\n",
    "        gradient = np.sum(y_true * self.loss.gradient(y_true, y_pred), axis=0)\n",
    "        # hessian矩阵计算\n",
    "        hessian = np.sum(self.loss.hess(y_true, y_pred), axis=0)\n",
    "        # 叶子结点得分\n",
    "        leaf_weight =  gradient / hessian\n",
    "        return leaf_weight\n",
    "\n",
    "    # 树拟合方法\n",
    "    def fit(self, X, y):\n",
    "        self.impurity_calculation = self.gain_xgb\n",
    "        self._leaf_value_calculation = self.leaf_weight\n",
    "        super(XGBoost_Single_Tree, self).fit(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "### 分类损失函数定义\n",
    "# 定义Sigmoid类\n",
    "class Sigmoid:\n",
    "    def __call__(self, x):\n",
    "        return 1 / (1 + np.exp(-x))\n",
    "\n",
    "    def gradient(self, x):\n",
    "        return self.__call__(x) * (1 - self.__call__(x))\n",
    "\n",
    "# 定义Logit损失\n",
    "class LogisticLoss:\n",
    "    def __init__(self):\n",
    "        sigmoid = Sigmoid()\n",
    "        self._func = sigmoid\n",
    "        self._grad = sigmoid.gradient\n",
    "    \n",
    "    # 定义损失函数形式\n",
    "    def loss(self, y, y_pred):\n",
    "        y_pred = np.clip(y_pred, 1e-15, 1 - 1e-15)\n",
    "        p = self._func(y_pred)\n",
    "        return y * np.log(p) + (1 - y) * np.log(1 - p)\n",
    "\n",
    "    # 定义一阶梯度\n",
    "    def gradient(self, y, y_pred):\n",
    "        p = self._func(y_pred)\n",
    "        return -(y - p)\n",
    "\n",
    "    # 定义二阶梯度\n",
    "    def hess(self, y, y_pred):\n",
    "        p = self._func(y_pred)\n",
    "        return p * (1 - p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "### XGBoost定义\n",
    "class XGBoost:\n",
    "    def __init__(self, n_estimators=300, learning_rate=0.001, \n",
    "                 min_samples_split=2,\n",
    "                 min_gini_impurity=999, \n",
    "                 max_depth=2):\n",
    "        # 树的棵树\n",
    "        self.n_estimators = n_estimators\n",
    "        # 学习率\n",
    "        self.learning_rate = learning_rate \n",
    "        # 结点分裂最小样本数\n",
    "        self.min_samples_split = min_samples_split \n",
    "        # 结点最小基尼不纯度\n",
    "        self.min_gini_impurity = min_gini_impurity  \n",
    "        # 树最大深度\n",
    "        self.max_depth = max_depth                  \n",
    "        # 用于分类的对数损失\n",
    "        # 回归任务可定义平方损失 \n",
    "        # self.loss = SquaresLoss()\n",
    "        self.loss = LogisticLoss()\n",
    "        # 初始化分类树列表\n",
    "        self.trees = []\n",
    "        # 遍历构造每一棵决策树\n",
    "        for _ in range(n_estimators):\n",
    "            tree = XGBoost_Single_Tree(\n",
    "                    min_samples_split=self.min_samples_split,\n",
    "                    min_gini_impurity=self.min_gini_impurity,\n",
    "                    max_depth=self.max_depth,\n",
    "                    loss=self.loss)\n",
    "            self.trees.append(tree)\n",
    "    \n",
    "    # xgboost拟合方法\n",
    "    def fit(self, X, y):\n",
    "        y = cat_label_convert(y)\n",
    "        y_pred = np.zeros(np.shape(y))\n",
    "        # 拟合每一棵树后进行结果累加\n",
    "        for i in range(self.n_estimators):\n",
    "            tree = self.trees[i]\n",
    "            y_true_pred = np.concatenate((y, y_pred), axis=1)\n",
    "            tree.fit(X, y_true_pred)\n",
    "            iter_pred = tree.predict(X)\n",
    "            y_pred -= np.multiply(self.learning_rate, iter_pred)\n",
    "\n",
    "    # xgboost预测方法\n",
    "    def predict(self, X):\n",
    "        y_pred = None\n",
    "        # 遍历预测\n",
    "        for tree in self.trees:\n",
    "            iter_pred = tree.predict(X)\n",
    "            if y_pred is None:\n",
    "                y_pred = np.zeros_like(iter_pred)\n",
    "            y_pred -= np.multiply(self.learning_rate, iter_pred)\n",
    "        y_pred = np.exp(y_pred) / np.sum(np.exp(y_pred), axis=1, keepdims=True)\n",
    "        # 将概率预测转换为标签\n",
    "        y_pred = np.argmax(y_pred, axis=1)\n",
    "        return y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:  0.9333333333333333\n"
     ]
    }
   ],
   "source": [
    "from sklearn import datasets\n",
    "# 导入鸢尾花数据集\n",
    "data = datasets.load_iris()\n",
    "# 获取输入输出\n",
    "X, y = data.data, data.target\n",
    "# 数据集划分\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=43)  \n",
    "# 创建xgboost分类器\n",
    "clf = XGBoost()\n",
    "# 模型拟合\n",
    "clf.fit(X_train, y_train)\n",
    "# 模型预测\n",
    "y_pred = clf.predict(X_test)\n",
    "# 准确率评估\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print (\"Accuracy: \", accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[14:56:53] WARNING: C:/Users/Administrator/workspace/xgboost-win64_release_1.3.0/src/learner.cc:1061: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'multi:softmax' was changed from 'merror' to 'mlogloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "Accuracy: 0.9666666666666667\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAEWCAYAAACOv5f1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHlhJREFUeJzt3XucVXW9//HXB0RuoygOY+CASIiYXOaghT1SHCRLLgKpP4JDAXrMQosoL9FFi3PyEafklKdfaZgKpCKSCWRopTBInTBEESzlok5HcZSLYM6IMDN8zh9rzbAZ5rJh9pq94ft+Ph77MXt/99p7vWfBvPfa37Vnjbk7IiISllbZDiAiIi1P5S8iEiCVv4hIgFT+IiIBUvmLiARI5S8iEiCVv0gdZnaXmd2S7RwiSTJ9zl8yxcxKgVOB6pThPu7+ZjOesxi4390Lm5fu6GRmc4E33P072c4ixxbt+UumXebueSmXIy7+TDCz47K5/uYws9bZziDHLpW/tAgzO9/M/sfMdpvZC/Eefc19V5nZS2b2npm9amZfjMc7Ao8D3cysPL50M7O5Zvb9lMcXm9kbKbdLzewbZrYeqDCz4+LHPWJm283sNTOb1kjW2ueveW4zu9nMtplZmZmNNbMRZrbJzN4xs2+lPPZ7ZvZrM1sYfz/PmdnAlPvPNrOSeDv8zcxG11nvnWa2zMwqgH8DJgI3x9/7b+PlZpjZK/Hz/93MPpPyHFPM7E9mdruZ7Yq/1+Ep93c2s/vM7M34/sUp940ys3Vxtv8xswFp/wPLUUflL4kzs9OA3wHfBzoDNwKPmFmXeJFtwCjgROAq4MdmNsjdK4DhwJtH8E5iAjASOAnYD/wWeAE4DRgGTDezT6f5XB8C2sWPvRW4G/gccC5wIXCrmfVKWX4MsCj+Xh8EFptZGzNrE+f4A1AAfAV4wMzOSnnsvwK3AScA84EHgB/G3/tl8TKvxOvtBMwE7jezrinPMRjYCOQDPwTuMTOL7/sV0AE4J87wYwAzGwTcC3wROAX4BbDUzNqmuY3kKKPyl0xbHO857k7Zq/wcsMzdl7n7fnf/I/AsMALA3X/n7q94ZCVROV7YzBz/7e6vu/se4KNAF3f/d3ff5+6vEhX4+DSfqxK4zd0rgYeISvUOd3/P3f8G/A1I3Ute6+6/jpf/L6IXjvPjSx4wK86xHHiM6IWqxhJ3/3O8nT6oL4y7L3L3N+NlFgKbgY+lLPIPd7/b3auBeUBX4NT4BWI48CV33+XulfH2BvgC8At3f8bdq919HrA3zizHoKN2PlRy1lh3f7LO2OnA/zOzy1LG2gArAOJpie8CfYh2SDoAG5qZ4/U66+9mZrtTxloDq9J8rp1xkQLsib++nXL/HqJSP2Td7r4/npLqVnOfu+9PWfYfRO8o6stdLzObBHwd6BkP5RG9INV4K2X978c7/XlE70Tecfdd9Tzt6cBkM/tKytjxKbnlGKPyl5bwOvArd/9C3TviaYVHgElEe72V8TuGmmmK+j6OVkH0AlHjQ/Usk/q414HX3P3MIwl/BLrXXDGzVkAhUDNd1d3MWqW8APQANqU8tu73e9BtMzud6F3LMOAv7l5tZus4sL0a8zrQ2cxOcvfd9dx3m7vflsbzyDFA0z7SEu4HLjOzT5tZazNrFx9ILSTau2wLbAeq4ncBn0p57NvAKWbWKWVsHTAiPnj5IWB6E+v/K/DP+CBw+zhDPzP7aMa+w4Oda2aXx580mk40fbIaeIbohevm+BhAMXAZ0VRSQ94GUo8ndCR6QdgO0cFyoF86ody9jOgA+s/N7OQ4w5D47ruBL5nZYIt0NLORZnZCmt+zHGVU/pI4d3+d6CDot4hK63XgJqCVu78HTAMeBnYRHfBcmvLYl4EFwKvxcYRuRActXwBKiY4PLGxi/dVEJVsEvAbsAH5JdMA0CUuAzxJ9P58HLo/n1/cBo4nm3XcAPwcmxd9jQ+4BPlJzDMXd/w7MBv5C9MLQH/jzYWT7PNExjJeJDrRPB3D3Z4nm/f9/nHsLMOUwnleOMvolL5EMMrPvAb3d/XPZziLSGO35i4gESOUvIhIgTfuIiARIe/4iIgHK2c/5n3TSSd67d+9sx2hURUUFHTt2zHaMRilj8+V6PlDGTDkWMq5du3aHu3dpcIEa7p6Tlz59+niuW7FiRbYjNEkZmy/X87krY6YcCxmBZz2NjtW0j4hIgFT+IiIBUvmLiARI5S8iEiCVv4hIgFT+IiIBUvmLiARI5S8iEiCVv4hIgFT+IiIBUvmLiARI5S8iEiCVv4hIgFT+IiIBUvmLiARI5S8iEiCVv4hIgFT+IiIBUvmLiARI5S8iEiCVv4hIgFT+IiIBUvmLiARI5S8iEiCVv4hIgFT+IiIBUvmLiARI5S8iEiCVv4hIgFT+IiIBUvmLiARI5S8iEiCVv4hIgFT+IiIBUvmLiARI5S8iEiCVv4hIgMzds52hXj169fZW4+7IdoxG3dC/itkbjst2jEYpY/Plej5QxkzJZMbSWSMz8jx1lZSUUFxc3OD9ZrbW3c9r6nm05y8ikpAPPviAj33sYwwcOJBzzjmH7373uwBceOGFFBUVUVRURLdu3Rg7diwAS5YsYcCAARQVFXHeeefxpz/9KbFsib0Mm9k0YCrQF9gQD5cDU939haTWKyKSK9q2bcvy5cvJy8ujsrKSCy64gOHDh7Nq1araZa644grGjBkDwLBhwxg9ejRmxvr16xk3bhwvv/xyItmSfA92HTAc6Aq85O67zGw4MAcYnOB6RURygpmRl5cHQGVlJZWVlZhZ7f3vvfcey5cv57777gOoXRagoqLioGUzLZFpHzO7C+gFLAUGu/uu+K7VQGES6xQRyUXV1dUUFRVRUFDAJZdcwuDBB/Z9H330UYYNG8aJJ5540Fjfvn0ZOXIk9957b2K5Ejvga2alwHnuviNl7Eagr7tf08BjrgWuBcjP73LurT+5O5FsmXJqe3h7T7ZTNE4Zmy/X84EyZkomM/Y/rdNBt8vLy7nllluYNm0aZ5xxBgDf+MY3GDFiBBdddNEhj3/hhReYP38+s2fPPuR5Ut8h1DV06NC0Dvi2WPmb2VDg58AF7r6zqcfr0z6ZoYzNl+v5QBkzJelP+8ycOZOOHTty4403snPnTvr06cPWrVtp165dvc9xxhlnsGbNGvLz82vHjqpP+5jZAOCXwJh0il9E5Fiwfft2du/eDcCePXt48skn6du3LwCLFi1i1KhRBxX/li1bqNkhf+6559i3bx+nnHJKItkSfxk2sx7Ab4DPu/umpNcnIpIrysrKmDx5MtXV1ezfv59x48YxatQoAB566CFmzJhx0PKPPPII8+fPp02bNrRv356FCxcmd9DX3RO5AKVAPtEe/y5gXXx5Np3H9+nTx3PdihUrsh2hScrYfLmez10ZM+VYyJhuxya25+/uPeOr18QXERHJEfoNXxGRAKn8RUQCpPIXEQmQyl9EJEAqfxGRAKn8RUQCpPIXEQmQyl9EJEAqfxGRAKn8RUQCpPIXEQmQyl9EJEAqfxGRAKn8RUQCpPIXEQmQyl9EJEAqfxGRAKn8RUQCpPIXEQmQyl9EJEAqfxGRAKn8RUQCpPIXEQmQyl9EJEAqfxGRAKn8RUQCpPIXEQmQyl9EJEAqfxGRAKn8RUQCpPIXEQmQyl9EJEAqfxGRAJm7ZztDvXr06u2txt2R7RiNuqF/FbM3HJftGI1SxubL9XygjJlSX8bSWSP54IMPGDJkCHv37qWqqoorr7ySmTNnMmXKFFauXEmnTp0AmDt3LkVFRQCUlJQwffp0Kisryc/PZ+XKlRnJWFJSQnFxcYP3m9ladz+vqedJ9F/CzKYBU4EPAa8D+4EqYLq7/ynJdYuIZErbtm1Zvnw5eXl5VFZWcsEFFzB8+HAAfvSjH3HllVcetPzu3bu57rrreOKJJ+jRowfbtm3LRuxGJf0yfB0wHNgOVLi7m9kA4GGgb8LrFhHJCDMjLy8PgMrKSiorKzGzBpd/8MEHufzyy+nRowcABQUFLZLzcCQ2529mdwG9gKXAF/zA/FJHIDfnmkREGlBdXU1RUREFBQVccsklDB48GIBvf/vbDBgwgK997Wvs3bsXgE2bNrFr1y6Ki4s599xzmT9/fjaj1yvROX8zKwXOc/cdZvYZ4AdAATDS3f9Sz/LXAtcC5Od3OffWn9ydWLZMOLU9vL0n2ykap4zNl+v5QBkzpb6M/U/rdNDt8vJybrnlFqZNm8aJJ55I586dqaysZPbs2XTr1o3Jkydzxx13sHHjRmbPns2+ffu4/vrr+cEPfkD37t2bnbG8vLz2XUh9hg4dmv05/1Tu/ijwqJkNAf4D+GQ9y8wB5kB0wPdoPDiUa5Sx+XI9HyhjptR7wHdi8SHLrV27lp07d3LVVVfVjh1//PHcfvvtFBcXs3r1agYOHFh7XGDp0qW0a9eu0QO16WrqgG+6Wvyjnu7+NPBhM8tv6XWLiByJ7du3s3v3bgD27NnDk08+Sd++fSkrKwPA3Vm8eDH9+vUDYMyYMaxatYqqqiref/99nnnmGc4+++ys5a/PYb8Mm9nJQHd3X38Yj+kNvBIf8B0EHA/sPNx1i4hkQ1lZGZMnT6a6upr9+/czbtw4Ro0axcUXX8z27dtxd4qKirjrrrsAOPvss7n00ksZMGAArVq14pprrql9YcgVaZW/mZUAo+Pl1wHbzWylu389zfVcAUwys0pgD/BZz9VfMBARqWPAgAE8//zzh4wvX768wcfcdNNN3HTTTUnGah53b/ICPB9/vQaYGV9fn85jj/TSp08fz3UrVqzIdoQmKWPz5Xo+d2XMlGMhI/Csp9Gx6c75H2dmXYFxwGMJvAaJiEgLSrf8/x34PdG8/Roz6wVsTi6WiIgkKa05f3dfBCxKuf0q0Ty+iIgchdLa8zezPmb2lJm9GN8eYGbfSTaaiIgkJd1pn7uBbwKVAB59zHN8UqFERCRZ6ZZ/B3f/a52xqkyHERGRlpFu+e8wsw8Tn5DNzK4EyhJLJSIiiUr3N3yvJzrnTl8z2wq8BkxMLJWIiCSqyfI3s1ZEZ+b8pJl1BFq5+3vJRxMRkaQ0Oe3j7vuBL8fXK1T8IiJHv3Tn/P9oZjeaWXcz61xzSTSZiIgkJt05/6vjr9enjDnRX+oSEZGjTLq/4XtG0kFERKTlpHtK50n1jbt77v1hShERaVK60z4fTbneDhgGPAeo/EVEjkLpTvt8JfW2mXUCfpVIIhERSdyR/g3f94EzMxlERERaTrpz/r8lPrUD0QvGR0g5xbOIiBxd0p3zvz3lehXwD3d/I4E8IiLSAtKd9hnh7ivjy5/d/Q0z+89Ek4mISGLSLf9L6hkbnskgIiLSchqd9jGzqcB1QC8zW59y1wnAn5MMJiIiyWlqzv9B4HHgB8CMlPH33P2dxFKJiEiiGi1/d38XeBeYAGBmBUS/5JVnZnnu/r/JRxQRkUxL9w+4X2Zmm4n+iMtKoJToHYGIiByF0j3g+33gfGBTfJK3YWjOX0TkqJVu+Ve6+06glZm1cvcVQFGCuUREJEHp/pLXbjPLA1YBD5jZNqJf9hIRkaNQunv+Y4jO5zMdeAJ4BbgsqVAiIpKsdM/qWWFmpwNnuvs8M+sAtE42moiIJCXdT/t8Afg18It46DRgcVKhREQkWelO+1wPfAL4J4C7bwYKkgolIiLJSrf897r7vpobZnYcB07xLCIiR5l0P+2z0sy+BbQ3s0uIzvfz2+RiwZ7KanrO+F2Sq2i2G/pXMUUZm+1IMpbOGplQGpEwpLvnPwPYDmwAvggsA76TVCiRdF199dUUFBTQr1+/g8Z/+tOfctZZZ3HOOedw8803A/DAAw9QVFRUe2nVqhXr1q3LRmyRrGvqrJ493P1/3X0/cHd8SZuZTQOmEv2x953ACKKPjE5x9+eOLLLIAVOmTOHLX/4ykyZNqh1bsWIFS5YsYf369bRt25Zt27YBMHHiRCZOnAjAhg0bGDNmDEVF+l1FCVNTe/61n+gxs0eO4PmvIyr8B4j+5u+ZwLXAnUfwXCKHGDJkCJ07dz5o7M4772TGjBm0bdsWgIKCQz+bsGDBAiZMmNAiGUVyUVPlbynXex3OE5vZXfFjlgKPAvM9sho4ycy6HlZSkTRt2rSJVatWMXjwYC666CLWrFlzyDILFy5U+UvQmjrg6w1cb5K7f8nMLgWGAnOB11PufoPodwXKUh9jZtcSvTMgP78Lt/bP7TNInNo+OliZy47VjCUlJbXX33rrLSoqKmrH3n33XTZs2MCsWbN4+eWXGT16NA8++CBm0b7M3//+d9ydHTt2HPQ8DSkvL09ruWxSxswIKWNT5T/QzP5J9A6gfXyd+La7+4lprsfqGTvkxcTd5wBzAHr06u2zN6T7YaTsuKF/FcrYfEeSsXRi8YHrpaV07NiR4uJo7KyzzmLatGkUFxczdOhQbr/9dvr160eXLl0AWLJkCddcc03t8k0pKSlJe9lsUcbMCCljo9M+7t7a3U909xPc/bj4es3tdIsfoj397im3C4E3jySwSFPGjh3L8uXLgWgKaN++feTn5wOwf/9+Fi1axPjx47MZUSTr0v2oZ3MtBSZZ5HzgXXcva+pBIk2ZMGECH//4x9m4cSOFhYXcc889XH311bz66qv069eP8ePHM2/evNopn6effprCwkJ69TqsQ1gix5yWmg9YRvSpny1EH/W8qqkHtG/Tmo05/os8JSUlB00/5KJjPeOCBQvqHb///vvrHS8uLmb16tVHtC6RY0mi5e/uPVNuXp/kukREJH0tNe0jIiI5ROUvIhIglb+ISIBU/iIiAVL5i4gESOUvIhIglb+ISIBU/iIiAVL5i4gESOUvIhIglb+ISIBU/iIiAVL5i4gESOUvIhIglb+ISIBU/iIiAVL5i4gESOUvIhIglb+ISIBU/iIiAVL5i4gESOUvIhIglb+ISIBU/iIiAVL5i4gESOUvIhIglb+ISIBU/iIiAVL5i4gESOUvIhIglb+ISIBU/iIiAVL5i4gE6LhsB2jInspqes74XbZjNOqG/lVMyfGMcy/tCMDVV1/NY489RkFBAS+++CIA77zzDp/97GcpLS2lZ8+ePPzww5x88sm4O1/96ldZtmwZHTp0YO7cuQwaNCib34aIZFhie/5mNs3MXjKzR8zsL2a218xuTGp90rgpU6bwxBNPHDQ2a9Yshg0bxubNmxk2bBizZs0C4PHHH2fz5s1s3ryZOXPmMHXq1GxEFpEEJTntcx0wApgKTANuT3Bd0oQhQ4bQuXPng8aWLFnC5MmTAZg8eTKLFy+uHZ80aRJmxvnnn8/u3bspKytr8cwikpxEyt/M7gJ6AUuBie6+BqhMYl1y5N5++226du0KQNeuXdm2bRsAW7dupXv37rXLFRYWsnXr1qxkFJFkJDLn7+5fMrNLgaHuviPdx5nZtcC1APn5Xbi1f1US8TLm1PbRvH8uKy8vp6SkBIC33nqLioqK2ttVVVW111Nv79ixg+eff56qquh727VrF2vXrqW8vDzxjLko1/OBMmZKSBlz6oCvu88B5gD06NXbZ2/IqXiHuKF/Fbmece6lHSkuLgagtLSUjh0P3D7ttNM466yz6Nq1K2VlZXTr1o3i4mIGDhxIfn5+7XIVFRWMHj269l1CppWUlNSuKxflej5QxkwJKaM+6hmw0aNHM2/ePADmzZvHmDFjasfnz5+Pu7N69Wo6deqUWPGLSHbk9m6rZMyECRNqp3QKCwuZOXMmM2bMYNy4cdxzzz306NGDRYsWATBixAiWLVtG79696dChA/fdd1+W04tIpiVe/mb2IeBZ4ERgv5lNBz7i7v9Met1ywIIFC+odf+qppw4ZMzN+9rOfJR1JRLIosfJ3954pNwsP9/Ht27Rm46yRmQuUgJKSEkonFmc7RqNy/eCViGSH5vxFRAKk8hcRCZDKX0QkQCp/EZEAqfxFRAKk8hcRCZDKX0QkQCp/EZEAqfxFRAKk8hcRCZDKX0QkQCp/EZEAqfxFRAKk8hcRCZDKX0QkQCp/EZEAqfxFRAKk8hcRCZDKX0QkQCp/EZEAqfxFRAKk8hcRCZDKX0QkQCp/EZEAqfxFRAKk8hcRCZDKX0QkQCp/EZEAqfxFRAKk8hcRCZDKX0QkQCp/EZEAqfxFRAKk8hcRCZDKX0QkQCp/EZEAqfxFRAKk8hcRCZC5e7Yz1MvM3gM2ZjtHE/KBHdkO0QRlbL5czwfKmCnHQsbT3b1LU09yXObyZNxGdz8v2yEaY2bPKmPz5XrGXM8HypgpIWXUtI+ISIBU/iIiAcrl8p+T7QBpUMbMyPWMuZ4PlDFTgsmYswd8RUQkObm85y8iIglR+YuIBCgny9/MLjWzjWa2xcxmZDsPgJmVmtkGM1tnZs/GY53N7I9mtjn+enILZ7rXzLaZ2YspY/Vmssh/x9t0vZkNymLG75nZ1nhbrjOzESn3fTPOuNHMPt1CGbub2Qoze8nM/mZmX43Hc2JbNpIvZ7ajmbUzs7+a2Qtxxpnx+Blm9ky8DRea2fHxeNv49pb4/p5ZzDjXzF5L2Y5F8XhWfmbidbc2s+fN7LH4dua3o7vn1AVoDbwC9AKOB14APpIDuUqB/DpjPwRmxNdnAP/ZwpmGAIOAF5vKBIwAHgcMOB94JosZvwfcWM+yH4n/vdsCZ8T/D1q3QMauwKD4+gnApjhLTmzLRvLlzHaMt0VefL0N8Ey8bR4GxsfjdwFT4+vXAXfF18cDC1vg37mhjHOBK+tZPis/M/G6vw48CDwW3874dszFPf+PAVvc/VV33wc8BIzJcqaGjAHmxdfnAWNbcuXu/jTwTpqZxgDzPbIaOMnMumYpY0PGAA+5+153fw3YQvT/IVHuXubuz8XX3wNeAk4jR7ZlI/ka0uLbMd4W5fHNNvHFgYuBX8fjdbdhzbb9NTDMzCxLGRuSlZ8ZMysERgK/jG8bCWzHXCz/04DXU26/QeP/0VuKA38ws7Vmdm08dqq7l0H0AwoUZC3dAQ1lyrXt+uX4rfS9KdNlWc8Yv23+F6K9wpzblnXyQQ5tx3iqYh2wDfgj0TuO3e5eVU+O2ozx/e8Cp7R0Rnev2Y63xdvxx2bWtm7GevIn6SfAzcD++PYpJLAdc7H863vVyoXPo37C3QcBw4HrzWxItgMdplzarncCHwaKgDJgdjye1Yxmlgc8Akx39382tmg9Y4nnrCdfTm1Hd6929yKgkOidxtmN5MiJjGbWD/gm0Bf4KNAZ+Ea2MprZKGCbu69NHW4kxxFnzMXyfwPonnK7EHgzS1lqufub8ddtwKNE/7nfrnkbGH/dlr2EtRrKlDPb1d3fjn8I9wN3c2BKImsZzawNUbE+4O6/iYdzZlvWly8Xt2OcazdQQjRPfpKZ1ZxDLDVHbcb4/k6kPz2YyYyXxtNq7u57gfvI7nb8BDDazEqJprwvJnonkPHtmIvlvwY4Mz66fTzRQYyl2QxkZh3N7ISa68CngBfjXJPjxSYDS7KT8CANZVoKTIo/wXA+8G7NlEZLqzNv+hmibQlRxvHxJxjOAM4E/toCeQy4B3jJ3f8r5a6c2JYN5cul7WhmXczspPh6e+CTRMcmVgBXxovV3YY12/ZKYLnHRy1bOOPLKS/wRjSXnrodW/Rnxt2/6e6F7t6TqPuWu/tEktiOLXHk+nAvREfZNxHNGX47B/L0Ivr0xAvA32oyEc2tPQVsjr92buFcC4je7lcS7QH8W0OZiN4e/izephuA87KY8VdxhvXxf96uKct/O864ERjeQhkvIHqrvB5YF19G5Mq2bCRfzmxHYADwfJzlReDWeLwX0QvPFmAR0DYebxff3hLf3yuLGZfH2/FF4H4OfCIoKz8zKXmLOfBpn4xvR53eQUQkQLk47SMiIglT+YuIBEjlLyISIJW/iEiAVP4iIgHK5T/gLpIIM6sm+uhejbHuXpqlOCJZoY96SnDMrNzd81pwfcf5gfOyiOQETfuI1GFmXc3s6fjc7i+a2YXx+KVm9lx8Pvin4rHOZrY4PinYajMbEI9/z8zmmNkfgPnxCcV+ZGZr4mW/mMVvUUTTPhKk9vGZHQFec/fP1Ln/X4Hfu/ttZtYa6GBmXYjOnzPE3V8zs87xsjOB5919rJldDMwnOtEawLnABe6+Jz4T7Lvu/tH4rJF/NrM/eHTKZZEWp/KXEO3x6MyODVkD3BufTG2xu68zs2Lg6Zqydveak2ddAFwRjy03s1PMrFN831J33xNf/xQwwMxqzs/SieicOyp/yQqVv0gd7v50fMrukcCvzOxHwG7qP1VuY6fUraiz3Ffc/fcZDStyhDTnL1KHmZ1OdE71u4nOpjkI+AtwUXyWTFKmfZ4GJsZjxcAOr/9vAfwemBq/m8DM+sRniBXJCu35ixyqGLjJzCqBcmCSu2+P5+1/Y2atiM7tfwnR39G9z8zWA+9z4PS6df0S6Ak8F586eDst/Gc/RVLpo54iIgHStI+ISIBU/iIiAVL5i4gESOUvIhIglb+ISIBU/iIiAVL5i4gE6P8Ap5spSUxkDO0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import xgboost as xgb\n",
    "from xgboost import plot_importance\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# 设置模型参数\n",
    "params = {\n",
    "    'booster': 'gbtree',\n",
    "    'objective': 'multi:softmax',   \n",
    "    'num_class': 3,     \n",
    "    'gamma': 0.1,\n",
    "    'max_depth': 2,\n",
    "    'lambda': 2,\n",
    "    'subsample': 0.7,\n",
    "    'colsample_bytree': 0.7,\n",
    "    'min_child_weight': 3,\n",
    "    'eta': 0.001,\n",
    "    'seed': 1000,\n",
    "    'nthread': 4,\n",
    "}\n",
    "\n",
    "\n",
    "dtrain = xgb.DMatrix(X_train, y_train)\n",
    "num_rounds = 200\n",
    "model = xgb.train(params, dtrain, num_rounds)\n",
    "# 对测试集进行预测\n",
    "dtest = xgb.DMatrix(X_test)\n",
    "y_pred = model.predict(dtest)\n",
    "\n",
    "# 计算准确率\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print (\"Accuracy:\", accuracy)\n",
    "# 绘制特征重要性\n",
    "plot_importance(model)\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
